1   /*
2    * Copyright (C) 2011 The Guava Authors
3    *
4    * Licensed under the Apache License, Version 2.0 (the "License");
5    * you may not use this file except in compliance with the License.
6    * You may obtain a copy of the License at
7    *
8    * http://www.apache.org/licenses/LICENSE-2.0
9    *
10   * Unless required by applicable law or agreed to in writing, software
11   * distributed under the License is distributed on an "AS IS" BASIS,
12   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13   * See the License for the specific language governing permissions and
14   * limitations under the License.
15   */
16  
17  package com.google.common.math;
18  
19  import static com.google.common.base.Preconditions.checkArgument;
20  import static com.google.common.base.Preconditions.checkNotNull;
21  import static com.google.common.math.MathPreconditions.checkNonNegative;
22  import static com.google.common.math.MathPreconditions.checkPositive;
23  import static com.google.common.math.MathPreconditions.checkRoundingUnnecessary;
24  import static java.math.RoundingMode.CEILING;
25  import static java.math.RoundingMode.FLOOR;
26  import static java.math.RoundingMode.HALF_EVEN;
27  
28  import com.google.common.annotations.GwtCompatible;
29  import com.google.common.annotations.VisibleForTesting;
30  
31  import java.math.BigInteger;
32  import java.math.RoundingMode;
33  import java.util.ArrayList;
34  import java.util.List;
35  
36  /**
37   * A class for arithmetic on values of type {@code BigInteger}.
38   *
39   * <p>The implementations of many methods in this class are based on material from Henry S. Warren,
40   * Jr.'s <i>Hacker's Delight</i>, (Addison Wesley, 2002).
41   *
42   * <p>Similar functionality for {@code int} and for {@code long} can be found in
43   * {@link IntMath} and {@link LongMath} respectively.
44   *
45   * @author Louis Wasserman
46   * @since 11.0
47   */
48  @GwtCompatible(emulated = true)
49  public final class BigIntegerMath {
50    /**
51     * Returns {@code true} if {@code x} represents a power of two.
52     */
53    public static boolean isPowerOfTwo(BigInteger x) {
54      checkNotNull(x);
55      return x.signum() > 0 && x.getLowestSetBit() == x.bitLength() - 1;
56    }
57  
58    /**
59     * Returns the base-2 logarithm of {@code x}, rounded according to the specified rounding mode.
60     *
61     * @throws IllegalArgumentException if {@code x <= 0}
62     * @throws ArithmeticException if {@code mode} is {@link RoundingMode#UNNECESSARY} and {@code x}
63     *         is not a power of two
64     */
65    @SuppressWarnings("fallthrough")
66    // TODO(kevinb): remove after this warning is disabled globally
67    public static int log2(BigInteger x, RoundingMode mode) {
68      checkPositive("x", checkNotNull(x));
69      int logFloor = x.bitLength() - 1;
70      switch (mode) {
71        case UNNECESSARY:
72          checkRoundingUnnecessary(isPowerOfTwo(x)); // fall through
73        case DOWN:
74        case FLOOR:
75          return logFloor;
76  
77        case UP:
78        case CEILING:
79          return isPowerOfTwo(x) ? logFloor : logFloor + 1;
80  
81        case HALF_DOWN:
82        case HALF_UP:
83        case HALF_EVEN:
84          if (logFloor < SQRT2_PRECOMPUTE_THRESHOLD) {
85            BigInteger halfPower = SQRT2_PRECOMPUTED_BITS.shiftRight(
86                SQRT2_PRECOMPUTE_THRESHOLD - logFloor);
87            if (x.compareTo(halfPower) <= 0) {
88              return logFloor;
89            } else {
90              return logFloor + 1;
91            }
92          }
93          /*
94           * Since sqrt(2) is irrational, log2(x) - logFloor cannot be exactly 0.5
95           *
96           * To determine which side of logFloor.5 the logarithm is, we compare x^2 to 2^(2 *
97           * logFloor + 1).
98           */
99          BigInteger x2 = x.pow(2);
100         int logX2Floor = x2.bitLength() - 1;
101         return (logX2Floor < 2 * logFloor + 1) ? logFloor : logFloor + 1;
102 
103       default:
104         throw new AssertionError();
105     }
106   }
107 
108   /*
109    * The maximum number of bits in a square root for which we'll precompute an explicit half power
110    * of two. This can be any value, but higher values incur more class load time and linearly
111    * increasing memory consumption.
112    */
113   @VisibleForTesting static final int SQRT2_PRECOMPUTE_THRESHOLD = 256;
114 
115   @VisibleForTesting static final BigInteger SQRT2_PRECOMPUTED_BITS =
116       new BigInteger("16a09e667f3bcc908b2fb1366ea957d3e3adec17512775099da2f590b0667322a", 16);
117 
118   private static final double LN_10 = Math.log(10);
119   private static final double LN_2 = Math.log(2);
120 
121   /**
122    * Returns {@code n!}, that is, the product of the first {@code n} positive
123    * integers, or {@code 1} if {@code n == 0}.
124    *
125    * <p><b>Warning:</b> the result takes <i>O(n log n)</i> space, so use cautiously.
126    *
127    * <p>This uses an efficient binary recursive algorithm to compute the factorial
128    * with balanced multiplies.  It also removes all the 2s from the intermediate
129    * products (shifting them back in at the end).
130    *
131    * @throws IllegalArgumentException if {@code n < 0}
132    */
133   public static BigInteger factorial(int n) {
134     checkNonNegative("n", n);
135 
136     // If the factorial is small enough, just use LongMath to do it.
137     if (n < LongMath.factorials.length) {
138       return BigInteger.valueOf(LongMath.factorials[n]);
139     }
140 
141     // Pre-allocate space for our list of intermediate BigIntegers.
142     int approxSize = IntMath.divide(n * IntMath.log2(n, CEILING), Long.SIZE, CEILING);
143     ArrayList<BigInteger> bignums = new ArrayList<BigInteger>(approxSize);
144 
145     // Start from the pre-computed maximum long factorial.
146     int startingNumber = LongMath.factorials.length;
147     long product = LongMath.factorials[startingNumber - 1];
148     // Strip off 2s from this value.
149     int shift = Long.numberOfTrailingZeros(product);
150     product >>= shift;
151 
152     // Use floor(log2(num)) + 1 to prevent overflow of multiplication.
153     int productBits = LongMath.log2(product, FLOOR) + 1;
154     int bits = LongMath.log2(startingNumber, FLOOR) + 1;
155     // Check for the next power of two boundary, to save us a CLZ operation.
156     int nextPowerOfTwo = 1 << (bits - 1);
157 
158     // Iteratively multiply the longs as big as they can go.
159     for (long num = startingNumber; num <= n; num++) {
160       // Check to see if the floor(log2(num)) + 1 has changed.
161       if ((num & nextPowerOfTwo) != 0) {
162         nextPowerOfTwo <<= 1;
163         bits++;
164       }
165       // Get rid of the 2s in num.
166       int tz = Long.numberOfTrailingZeros(num);
167       long normalizedNum = num >> tz;
168       shift += tz;
169       // Adjust floor(log2(num)) + 1.
170       int normalizedBits = bits - tz;
171       // If it won't fit in a long, then we store off the intermediate product.
172       if (normalizedBits + productBits >= Long.SIZE) {
173         bignums.add(BigInteger.valueOf(product));
174         product = 1;
175         productBits = 0;
176       }
177       product *= normalizedNum;
178       productBits = LongMath.log2(product, FLOOR) + 1;
179     }
180     // Check for leftovers.
181     if (product > 1) {
182       bignums.add(BigInteger.valueOf(product));
183     }
184     // Efficiently multiply all the intermediate products together.
185     return listProduct(bignums).shiftLeft(shift);
186   }
187 
188   static BigInteger listProduct(List<BigInteger> nums) {
189     return listProduct(nums, 0, nums.size());
190   }
191 
192   static BigInteger listProduct(List<BigInteger> nums, int start, int end) {
193     switch (end - start) {
194       case 0:
195         return BigInteger.ONE;
196       case 1:
197         return nums.get(start);
198       case 2:
199         return nums.get(start).multiply(nums.get(start + 1));
200       case 3:
201         return nums.get(start).multiply(nums.get(start + 1)).multiply(nums.get(start + 2));
202       default:
203         // Otherwise, split the list in half and recursively do this.
204         int m = (end + start) >>> 1;
205         return listProduct(nums, start, m).multiply(listProduct(nums, m, end));
206     }
207   }
208 
209  /**
210    * Returns {@code n} choose {@code k}, also known as the binomial coefficient of {@code n} and
211    * {@code k}, that is, {@code n! / (k! (n - k)!)}.
212    *
213    * <p><b>Warning:</b> the result can take as much as <i>O(k log n)</i> space.
214    *
215    * @throws IllegalArgumentException if {@code n < 0}, {@code k < 0}, or {@code k > n}
216    */
217   public static BigInteger binomial(int n, int k) {
218     checkNonNegative("n", n);
219     checkNonNegative("k", k);
220     checkArgument(k <= n, "k (%s) > n (%s)", k, n);
221     if (k > (n >> 1)) {
222       k = n - k;
223     }
224     if (k < LongMath.biggestBinomials.length && n <= LongMath.biggestBinomials[k]) {
225       return BigInteger.valueOf(LongMath.binomial(n, k));
226     }
227 
228     BigInteger accum = BigInteger.ONE;
229 
230     long numeratorAccum = n;
231     long denominatorAccum = 1;
232 
233     int bits = LongMath.log2(n, RoundingMode.CEILING);
234 
235     int numeratorBits = bits;
236 
237     for (int i = 1; i < k; i++) {
238       int p = n - i;
239       int q = i + 1;
240 
241       // log2(p) >= bits - 1, because p >= n/2
242 
243       if (numeratorBits + bits >= Long.SIZE - 1) {
244         // The numerator is as big as it can get without risking overflow.
245         // Multiply numeratorAccum / denominatorAccum into accum.
246         accum = accum
247             .multiply(BigInteger.valueOf(numeratorAccum))
248             .divide(BigInteger.valueOf(denominatorAccum));
249         numeratorAccum = p;
250         denominatorAccum = q;
251         numeratorBits = bits;
252       } else {
253         // We can definitely multiply into the long accumulators without overflowing them.
254         numeratorAccum *= p;
255         denominatorAccum *= q;
256         numeratorBits += bits;
257       }
258     }
259     return accum
260         .multiply(BigInteger.valueOf(numeratorAccum))
261         .divide(BigInteger.valueOf(denominatorAccum));
262   }
263 
264   // Returns true if BigInteger.valueOf(x.longValue()).equals(x).
265 
266   private BigIntegerMath() {}
267 }
268